import os
import pickle
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import models
import numpy as np
from CIFAR10_GAN import *


# Report the CUDA situation first so a silent CPU fallback is visible in the
# log: line one is whether CUDA is usable, line two is the visible GPU count.
print(torch.cuda.is_available())
print(torch.cuda.device_count())

# Prefer the GPU and fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# `generator` and `discriminator` are not defined in this file; presumably they
# are model instances exposed by `from CIFAR10_GAN import *` (TODO confirm).
# Both names are rebound to whatever `.to(device)` returns.
generator = generator.to(device)
discriminator = discriminator.to(device)

# Work from the project directory so the relative paths used later in this
# script ('./data' and the checkpoint file names) resolve there.
new_path = "/users/eval/discriminative_approach"  # Replace with your desired path
os.chdir(new_path)
current_path = os.getcwd()
print("Current Path:", current_path)

##-----------------data loading-----------------##

# Per-channel (R, G, B) statistics handed to Normalize: (x - 0.5) / 0.5.
_NORM_MEAN = (0.5, 0.5, 0.5)
_NORM_STD = (0.5, 0.5, 0.5)

# Preprocessing shared by both splits: convert each image to a tensor, then
# normalize every RGB channel with the statistics above.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# CIFAR-10 train and test splits, downloaded into ./data when not yet present.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Batches of 128 served by two worker processes each; only the training
# loader shuffles, the test loader keeps the dataset order.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=128,
    shuffle=False,
    num_workers=2,
)

##-----------------dataset preview-----------------##

# How many samples each split holds.
print(f"Number of training samples: {len(trainset)}")
print(f"Number of testing samples: {len(testset)}")

# Inspect the first training sample (already passed through `transform`)
# and look up the dataset's class names.
image, label = trainset[0]
classes = trainset.classes

print(f"Image shape: {image.shape}")
print(f"Label: {label}")

# CIFAR-10 class names
print(f"Class names: {classes}")

print(f"Data type of image: {type(image)}")
print(f"Shape of image tensor: {image.shape}")

# Pull a single batch from the training loader to check the batched shapes.
# NOTE(review): `dataiter` stays referenced at module level, so the loader
# worker processes it started presumably stay alive for the rest of the
# script -- consider `del dataiter` once the preview is done.
dataiter = iter(trainloader)
images, labels = next(dataiter)

print(f"Batch image tensor shape: {images.shape}")
print(f"Batch label tensor shape: {labels.shape}")


##-----------------GAN model training process-----------------##

# Training schedule. `epoch` is zero-based, so NUM_EPOCHS = 501 makes epoch 500
# the last one run, which is what lets the final entry of SAVE_EPOCHS trigger.
NUM_EPOCHS = 501
SAVE_EPOCHS = (10, 20, 50, 100, 300, 500)
LATENT_DIM = 100    # length of the noise vector handed to the generator
LOG_INTERVAL = 100  # print the losses every this many batches

# Optimizers. Reached through `torch.` explicitly instead of relying on the
# names `optim` / `nn` happening to leak out of `from CIFAR10_GAN import *`.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002)

# Binary cross-entropy between the discriminator output and the 1/0 targets.
# NOTE(review): BCELoss expects inputs in [0, 1] with the same shape as the
# target, so this assumes the discriminator (defined in CIFAR10_GAN) returns
# probabilities of shape (batch, 1) -- TODO confirm.
adversarial_loss = torch.nn.BCELoss().to(device)

for epoch in range(NUM_EPOCHS):
    for i, (imgs, _) in enumerate(trainloader):
        imgs = imgs.to(device)
        real_imgs = imgs

        # Size everything from the batch actually received rather than the
        # loader's batch_size, since the last batch of an epoch can be smaller.
        n_samples = imgs.size(0)

        # Adversarial ground truths (1 = real, 0 = generated), allocated
        # directly on the target device to avoid a host-to-device copy.
        valid = torch.ones(n_samples, 1, device=device)
        fake = torch.zeros(n_samples, 1, device=device)

        # Noise input for the generator.
        z = torch.randn(n_samples, LATENT_DIM, device=device)

        # Train Generator: push the discriminator's verdict on generated
        # images towards "valid".
        optimizer_G.zero_grad()
        gen_imgs = generator(z)
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        # Train Discriminator: real images should score "valid", generated
        # ones "fake". `detach()` cuts the graph so this backward pass does
        # not reach the generator. The zero_grad() also discards the gradients
        # the generator step above left on the discriminator's parameters.
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # Print training stats every LOG_INTERVAL batches. Two adjacent
        # f-strings are used instead of a backslash continuation inside the
        # literal, which would embed the source indentation in the output.
        if i % LOG_INTERVAL == 0:
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {i}/{len(trainloader)} "
                f"Loss D: {d_loss.item():.4f}, Loss G: {g_loss.item():.4f}"
            )

    # Save both models' weights at the selected epochs (relative to the
    # current working directory).
    if epoch in SAVE_EPOCHS:
        torch.save(generator.state_dict(), f'state_dict_gene_full_{epoch}.pt')
        torch.save(discriminator.state_dict(), f'state_dict_disc_full_{epoch}.pt')
        print(f"epoch {epoch} has been saved")